package com.luis.toolsuite.nbc;

import org.apache.commons.lang3.StringUtils;

import java.io.*;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.nio.charset.StandardCharsets;
import java.util.*;

/**
 * Small keyword-based naive-Bayes experiment that labels subject lines as
 * "test" (1) or "formal" (0).
 *
 * <p>Pipeline executed by {@link #main(String[])}:
 * <ol>
 *   <li>{@code train()} reads {@code /tmp/train_subject.txt} and labels each line;</li>
 *   <li>{@code calcProbility()} derives P(Test), P(keyword) and P(keyword | Test);</li>
 *   <li>{@code evaluate()} scores every line of {@code /tmp/test_subject.txt};</li>
 *   <li>{@code flushResult()} writes both label maps as CSV files under {@code /tmp}.</li>
 * </ol>
 *
 * <p>All text files are read and written as UTF-8.
 */
public class NbcMain {

    /** Keywords whose presence in a subject counts as evidence of a test entry. */
    private static final List<String> ignoreList;
    /**
     * Used twice: as the share of keyword-matching training lines that are
     * randomly left unlabelled, and as the score threshold in evaluation.
     */
    private static final double FACTOR = 0.1;
    /** Small constant added to both probabilities before dividing, so the divisor is never zero. */
    private static final BigDecimal APPEND = BigDecimal.valueOf(0.00001);
    /** Number of decimal places kept for every computed probability and ratio. */
    private static final int RESULT_SCALE = 4;
    /** Map key under which the prior P(Test) is stored. */
    private static final String PRIOR_KEY = "1";
    /** Prefix of the map keys under which P(keyword | Test) is stored. */
    private static final String CONDITIONAL_PREFIX = "1@";

    static {
        ignoreList = Collections.unmodifiableList(Arrays.asList("测试", "123", "test", "TEST"));
    }

    /**
     * Runs training and evaluation and writes both result files, whose names
     * share the current timestamp as a version suffix.
     *
     * @param args unused
     * @throws IOException if an input file cannot be read or an output file cannot be written
     */
    public static void main(String[] args) throws IOException {
        long versionInfo = System.currentTimeMillis();
        // Train.
        Map<String, Integer> trainResult = train();
        // Persist the training labels.
        flushResult(trainResult, "训练结果_" + versionInfo + ".csv");
        // Evaluate.
        Map<String, BigDecimal> probability = calcProbility(trainResult);
        Map<String, Integer> testResult = evaluate(probability);
        // Persist the evaluation labels.
        flushResult(testResult, "计算结果_" + versionInfo + ".csv");
    }

    /**
     * Reads the training subjects and labels each one.
     *
     * <p>A line is labelled 1 (test) when it contains any keyword and a random
     * draw exceeds {@link #FACTOR}; every other line is labelled 0. Duplicate
     * lines collapse into one entry, the last label winning.
     *
     * @return subject line to label (1 = test, 0 = not test), in file order
     * @throws IOException if the training file cannot be opened or read
     */
    private static Map<String, Integer> train() throws IOException {
        Map<String, Integer> trainResult = new LinkedHashMap<>();
        try (BufferedReader reader = openUtf8Reader("/tmp/train_subject.txt")) {
            String line;
            while ((line = reader.readLine()) != null) {
                // The random draw is only taken for keyword-matching lines.
                boolean labelAsTest = containsAnyKeyword(line) && Math.random() > FACTOR;
                trainResult.put(line, labelAsTest ? 1 : 0);
            }
        }
        return trainResult;
    }

    /**
     * Computes the probabilities needed by {@code evalSingle}.
     *
     * <p>The returned map holds P(Test) under {@code "1"}, P(keyword) under the
     * keyword itself and P(keyword | Test) under {@code "1@" + keyword}. A ratio
     * whose denominator is zero (no training data, or no test-labelled subject)
     * is reported as zero instead of raising {@link ArithmeticException}.
     *
     * @param trainResult subject line to label (1 = test)
     * @return the probability table described above
     */
    private static Map<String, BigDecimal> calcProbility(Map<String, Integer> trainResult) {
        Map<String, BigDecimal> resultMap = new HashMap<>();
        int total = trainResult.size();

        // Subjects labelled as test; also the population for P(keyword | Test).
        List<String> testSubjects = new ArrayList<>();
        for (Map.Entry<String, Integer> entry : trainResult.entrySet()) {
            Integer label = entry.getValue();
            if (label != null && label == 1) {
                testSubjects.add(entry.getKey());
            }
        }
        long testCount = testSubjects.size();

        // P(Test)
        resultMap.put(PRIOR_KEY, ratio(testCount, total));

        for (String keyword : ignoreList) {
            // P(keyword): a keyword at the very start of the subject counts too.
            long keywordCount = countContaining(trainResult.keySet(), keyword);
            resultMap.put(keyword, ratio(keywordCount, total));
            // P(keyword | Test)
            long keywordTestCount = countContaining(testSubjects, keyword);
            resultMap.put(CONDITIONAL_PREFIX + keyword, ratio(keywordTestCount, testCount));
        }
        return resultMap;
    }

    /**
     * Scores every line of the evaluation file and labels it 1 when its score
     * is strictly greater than {@link #FACTOR}, otherwise 0.
     *
     * @param preProbability probability table produced by {@code calcProbility}
     * @return subject line to predicted label, in file order
     * @throws IOException if the evaluation file cannot be opened or read
     */
    private static Map<String, Integer> evaluate(Map<String, BigDecimal> preProbability) throws IOException {
        Map<String, Integer> result = new LinkedHashMap<>();
        BigDecimal threshold = BigDecimal.valueOf(FACTOR);
        try (BufferedReader reader = openUtf8Reader("/tmp/test_subject.txt")) {
            String line;
            while ((line = reader.readLine()) != null) {
                BigDecimal score = evalSingle(line, preProbability);
                result.put(line, score.compareTo(threshold) > 0 ? 1 : 0);
            }
        }
        return result;
    }

    /**
     * Scores one subject as
     * {@code P(Test) * product over matched keywords of P(w | Test) / P(w)}.
     *
     * <p>{@link #APPEND} is added to numerator and denominator of every ratio,
     * and each ratio is rounded to {@link #RESULT_SCALE} decimals. Probabilities
     * missing from the table are treated as zero.
     *
     * @param subject        subject line to score; {@code null} scores zero
     * @param preProbability probability table produced by {@code calcProbility}
     * @return the score, or {@link BigDecimal#ZERO} when no keyword matches
     */
    private static BigDecimal evalSingle(String subject, Map<String, BigDecimal> preProbability) {
        if (subject == null) {
            return BigDecimal.ZERO;
        }
        BigDecimal likelihoodRatio = BigDecimal.ONE;
        boolean matched = false;
        for (String keyword : ignoreList) {
            if (!subject.contains(keyword)) {
                continue;
            }
            matched = true;
            BigDecimal conditional =
                    preProbability.getOrDefault(CONDITIONAL_PREFIX + keyword, BigDecimal.ZERO).add(APPEND);
            BigDecimal marginal = preProbability.getOrDefault(keyword, BigDecimal.ZERO).add(APPEND);
            likelihoodRatio = likelihoodRatio.multiply(
                    conditional.divide(marginal, RESULT_SCALE, RoundingMode.HALF_UP));
        }
        if (!matched) {
            return BigDecimal.ZERO;
        }
        return likelihoodRatio.multiply(preProbability.getOrDefault(PRIOR_KEY, BigDecimal.ZERO));
    }

    /**
     * Writes a label map to {@code /tmp/<fileName>} as a UTF-8 CSV with a
     * leading byte-order mark and CRLF line endings, replacing any existing file.
     *
     * @param result   subject line to label (1 = test)
     * @param fileName name of the file to create under {@code /tmp}
     * @throws IOException if the file cannot be created or written
     */
    private static void flushResult(Map<String, Integer> result, String fileName) throws IOException {
        File target = new File("/tmp/" + fileName);
        try (BufferedWriter writer = new BufferedWriter(
                new OutputStreamWriter(new FileOutputStream(target), StandardCharsets.UTF_8))) {
            // Byte-order mark (U+FEFF), encoded by the writer as EF BB BF.
            writer.write(0xFEFF);
            for (Map.Entry<String, Integer> entry : result.entrySet()) {
                Integer label = entry.getValue();
                boolean isTest = label != null && label == 1;
                writer.write(entry.getKey() + "," + (isTest ? "测试会议" : "正式会议") + "\r\n");
            }
        }
    }

    /** Opens a text file for reading as UTF-8. */
    private static BufferedReader openUtf8Reader(String path) throws IOException {
        return new BufferedReader(new InputStreamReader(new FileInputStream(path), StandardCharsets.UTF_8));
    }

    /** Returns whether the subject contains at least one keyword. */
    private static boolean containsAnyKeyword(String subject) {
        if (subject == null) {
            return false;
        }
        for (String keyword : ignoreList) {
            if (subject.contains(keyword)) {
                return true;
            }
        }
        return false;
    }

    /** Counts the non-null subjects that contain the keyword at any position. */
    private static long countContaining(Collection<String> subjects, String keyword) {
        long count = 0;
        for (String subject : subjects) {
            if (subject != null && subject.contains(keyword)) {
                count++;
            }
        }
        return count;
    }

    /** Divides at {@link #RESULT_SCALE} decimals, yielding zero for a zero denominator. */
    private static BigDecimal ratio(long numerator, long denominator) {
        if (denominator == 0) {
            return BigDecimal.ZERO.setScale(RESULT_SCALE);
        }
        return BigDecimal.valueOf(numerator)
                .divide(BigDecimal.valueOf(denominator), RESULT_SCALE, RoundingMode.HALF_UP);
    }

}
